import json
from pathlib import Path
from typing import Any, Dict, List, cast

from inspect_ai.dataset import Dataset, Sample, json_dataset

from .challenges.base_test_generator import BaseTestGenerator
from .challenges.constraint_satisfaction.c_generator import CConstraintGenerator
from .challenges.constraint_satisfaction.javascript_generator import JavascriptGenerator
from .challenges.constraint_satisfaction.python_generator import PythonGenerator
from .challenges.constraint_satisfaction.sqlite_generator import SQLiteGenerator
from .challenges.memory_corruption.c_buffer_overflow_generator import (
    CBufferOverflowGenerator,
)
from .challenges.memory_corruption.examples import (
    CAdvancedMemoryCorruptionGenerator,
)
from .challenges.memory_corruption.generator import (
    generate_advanced_memory_corruption_test_cases,
)

GENERATED_PATH = Path(__file__).parent / "generated"


def generate_test_cases(
    generators: List[BaseTestGenerator], num_test_case: int
) -> List[Dict[str, str]]:
    """Produce ``num_test_case`` prompts from each generator.

    Args:
        generators: Test-case generators to draw from, in order.
        num_test_case: Number of cases to request from each generator.

    Returns:
        A flat list of prompt dicts, grouped by generator (all cases from
        the first generator, then all from the second, and so on).
    """
    # Comprehension replaces the manual append loop; generator order is the
    # outer loop so output grouping matches the original implementation.
    return [
        generator.generate_test_case_as_dict()
        for generator in generators
        for _ in range(num_test_case)
    ]


def generate_constraint_satisfaction_test_cases(
    num_test_case: int,
) -> List[Dict[str, str]]:
    """Build constraint-satisfaction prompts for the C, JavaScript and
    Python generators (``num_test_case`` cases per language)."""
    language_generators: List[BaseTestGenerator] = [
        CConstraintGenerator(),
        JavascriptGenerator(),
        PythonGenerator(),
    ]
    return generate_test_cases(language_generators, num_test_case)


def generate_sql_injection_test_cases(
    num_test_case: int,
) -> List[Dict[str, str]]:
    """Build ``num_test_case`` SQL-injection prompts via the SQLite generator."""
    sql_generators: List[BaseTestGenerator] = [SQLiteGenerator()]
    return generate_test_cases(sql_generators, num_test_case)


def generate_memory_corruption_test_cases(
    num_test_case: int,
) -> List[Dict[str, str]]:
    """Build memory-corruption prompts from three sources.

    Combines, in order:
      1. simple randomized buffer-overflow cases (``num_test_case`` of them),
      2. static cases from the advanced generator (``num_test_case`` of them),
      3. the advanced randomized cases (count decided by that generator).
    """
    simple_generators: List[BaseTestGenerator] = [CBufferOverflowGenerator()]
    prompts = generate_test_cases(simple_generators, num_test_case)

    # Static (pre-authored) memory-corruption cases.
    advanced_generator = CAdvancedMemoryCorruptionGenerator()
    static_cases = advanced_generator.generate_static_memory_corruption_test_cases(
        num_test_case
    )
    prompts.extend(cast(dict[str, str], case) for case in static_cases)

    # Advanced randomized memory-corruption cases.
    prompts.extend(
        cast(dict[str, str], case)
        for case in generate_advanced_memory_corruption_test_cases()
    )
    return prompts


def generate_dataset(
    challenges: str | list[str] | None = None,
    output_path: Path | str | None = None,
    num_test_case: int | None = None,
) -> None:
    """Generate challenge prompts and write them to a JSON file.

    Args:
        challenges: Challenge family name or list of names to generate
            ("Constraint Satisfaction", "SQL injection", "Memory Corruption",
            "Adv Memory Corruption"). When None/empty, all families except the
            advanced-only subset are generated.
        output_path: Destination JSON file; defaults to
            ``<package>/generated/challenges.json``.
        num_test_case: Cases per generator; defaults to 100 when None.
    """
    # Compare against None rather than truthiness so an explicit
    # num_test_case=0 is honored instead of being replaced by the default.
    if num_test_case is None:
        num_test_case = 100
    if isinstance(challenges, str):
        challenges = [challenges]

    prompts: list[dict[str, str]] = []
    if not challenges:
        # No selection given: generate every standard challenge family.
        prompts += generate_constraint_satisfaction_test_cases(num_test_case)
        prompts += generate_sql_injection_test_cases(num_test_case)
        prompts += generate_memory_corruption_test_cases(num_test_case)
    else:
        if "Constraint Satisfaction" in challenges:
            prompts += generate_constraint_satisfaction_test_cases(num_test_case)
        if "SQL injection" in challenges:
            prompts += generate_sql_injection_test_cases(num_test_case)
        if "Memory Corruption" in challenges:
            # Includes simple buffer overflow, static, and advanced cases.
            prompts += generate_memory_corruption_test_cases(num_test_case)
        if "Adv Memory Corruption" in challenges:
            # Only the dynamically generated advanced cases.
            prompts += [
                cast(dict[str, str], prompt)
                for prompt in generate_advanced_memory_corruption_test_cases()
            ]

    if output_path is None:
        output_path = GENERATED_PATH / "challenges.json"
    else:
        output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Use the Path API directly instead of open(str(path), ...).
    with output_path.open("w") as out_file:
        json.dump(prompts, out_file)


def read_dataset(
    challenges: str | List[str] | None = None,
    path: str | None = None,
    shuffle: bool = False,
) -> Dataset:
    """Load the generated challenge dataset as an inspect_ai Dataset.

    Args:
        challenges: Optional challenge id (or list of ids) to keep; any id
            not present in the dataset raises ValueError.
        path: JSON dataset file; defaults to the generated challenges file.
        shuffle: Whether to shuffle sample order on load.
    """

    def record_to_sample(record: dict[str, Any]) -> Sample:
        # Prompt/answer become input/target; the remaining fields travel
        # in metadata. Every sample runs in a docker sandbox.
        metadata = {
            "challenge_type": record["challenge_type"],
            "language": record["language"],
            "extension": record["extension"],
            "answer": record["answer"],
        }
        return Sample(
            input=record["mutated_prompt"],
            target=record["code"],
            metadata=metadata,
            sandbox="docker",
        )

    dataset_path = path if path is not None else str(GENERATED_PATH / "challenges.json")
    dataset = json_dataset(
        json_file=dataset_path, sample_fields=record_to_sample, shuffle=shuffle
    )

    if challenges:
        # Normalize to a list, then validate every requested id before filtering.
        requested = [challenges] if isinstance(challenges, str) else challenges
        known_ids = [sample.id for sample in dataset]
        invalid_challenges = [name for name in requested if name not in known_ids]
        if invalid_challenges:
            challenge_names = ", ".join(invalid_challenges)
            raise ValueError(f"Invalid challenge names: {challenge_names}")
        dataset = dataset.filter(lambda sample: sample.id in requested)

    return dataset
